# TODO drop Kherson point
# TODO investigate Unnamed
import pandas as pd
import seaborn as sns

# Fixed seed so the train/test split and the model are reproducible.
SEED = 14

# Pre-cleaned hotspot dataset produced by an earlier cleaning step
# (see "cleaned_data/"); 6010 rows x 12 columns per the info() output.
data = pd.read_parquet("cleaned_data/data.parquet")
print(data.info())
<class 'pandas.core.frame.DataFrame'> Int64Index: 6010 entries, 0 to 6010 Data columns (total 12 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 hotspot_id 6010 non-null int64 1 blacklist_score 6010 non-null float64 2 static_score 6010 non-null float64 3 dynamic_score 6010 non-null float64 4 connection_stats_score 6010 non-null float64 5 last_conn_date 6010 non-null datetime64[ns] 6 last_seen_date 6010 non-null datetime64[ns] 7 num_conn 6010 non-null int64 8 unique_conn 6010 non-null int64 9 percent_available 6010 non-null float64 10 percent_protected 6010 non-null float64 11 enabled_moderator 6010 non-null bool dtypes: bool(1), datetime64[ns](2), float64(6), int64(3) memory usage: 569.3 KB None
# Notebook-style display of the full DataFrame.
data
| hotspot_id | blacklist_score | static_score | dynamic_score | connection_stats_score | last_conn_date | last_seen_date | num_conn | unique_conn | percent_available | percent_protected | enabled_moderator | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 14650480 | 0.0 | 0.22 | 0.45 | 0.69 | 2022-08-21 | 2021-05-01 | 5 | 2 | 1.0 | 1.0 | True |
| 1 | 14110275 | 0.0 | 0.22 | 0.00 | 0.67 | 2022-02-04 | 2022-02-04 | 4 | 2 | 1.0 | 1.0 | True |
| 2 | 16012785 | 0.0 | 0.18 | 0.16 | 0.67 | 2022-02-05 | 2022-02-15 | 4 | 2 | 1.0 | 1.0 | True |
| 3 | 14863945 | 0.0 | 0.22 | 0.05 | 0.72 | 2022-03-15 | 2021-06-12 | 6 | 2 | 1.0 | 0.8 | True |
| 4 | 9295867 | 0.0 | 0.39 | 0.00 | 0.52 | 2017-12-07 | 2017-12-07 | 1 | 1 | 1.0 | 0.0 | True |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 6006 | 13213372 | 0.0 | 0.25 | 0.00 | 0.52 | 2020-04-27 | 2020-04-27 | 1 | 1 | 1.0 | 0.0 | True |
| 6007 | 5504114 | 0.0 | 0.22 | 0.00 | 0.52 | 2016-04-10 | 2022-01-02 | 1 | 1 | 1.0 | 0.0 | True |
| 6008 | 15109612 | 0.0 | 0.48 | 0.00 | 0.52 | 2021-07-20 | 2021-07-20 | 1 | 1 | 1.0 | 0.0 | True |
| 6009 | 502326 | 0.0 | 0.12 | 0.00 | 0.52 | 2013-11-19 | 2013-11-19 | 1 | 1 | 1.0 | 0.0 | True |
| 6010 | 14242378 | 0.0 | 0.22 | 0.00 | 0.52 | 2021-01-29 | 2021-01-29 | 1 | 1 | 1.0 | 0.0 | True |
6010 rows × 12 columns
# Distinct values per column. Per the output: enabled_moderator is constant
# (1 value) and blacklist_score / percent_available take only 2 values each.
data.nunique()
hotspot_id 6003 blacklist_score 2 static_score 79 dynamic_score 90 connection_stats_score 22 last_conn_date 2243 last_seen_date 2059 num_conn 615 unique_conn 57 percent_available 2 percent_protected 57 enabled_moderator 1 dtype: int64
# small_variaty_columns = hotspots_nunique[hotspots_nunique < 50].index
# for column in small_variaty_columns:
# print(column)
# print(hotspots[column].value_counts(dropna = False))
# print()
# Quality levels from worst to best; index in this list == category code.
SORTED_QUALITY = ["spam", "bad", "moderate", "good"]


def calculate_quality(scores):
    """Derive an ordered quality label for every row of *scores*.

    A row is "spam" when its blacklist score equals 1; otherwise the label
    is taken from the dynamic score: < 0.3 -> "bad", < 0.6 -> "moderate",
    else "good".

    Parameters
    ----------
    scores : pd.DataFrame
        Must contain "blacklist_score" and "dynamic_score" columns.

    Returns
    -------
    pd.Series
        Ordered categorical labels ("spam" < "bad" < "moderate" < "good"),
        indexed like *scores*.
    """
    def calculate_quality_for_row(row):
        # Blacklisted hotspots are spam regardless of their other scores.
        if row["blacklist_score"] == 1:
            return "spam"
        dynamic_score = row["dynamic_score"]
        if dynamic_score < 0.3:
            return "bad"
        if dynamic_score < 0.6:  # 0.3 <= dynamic_score < 0.6
            return "moderate"
        return "good"

    labels = scores.apply(calculate_quality_for_row, axis=1)
    # Use an explicit ordered CategoricalDtype instead of the original
    # .astype("category").cat.reorder_categories(SORTED_QUALITY): the
    # original raised ValueError whenever some quality level was absent
    # from the data (reorder_categories requires the exact same category
    # set), and it never actually marked the categorical as ordered.
    return labels.astype(pd.CategoricalDtype(SORTED_QUALITY, ordered=True))


def calculate_quality_code(scores):
    """Return the integer category code per row (0=spam .. 3=good) as a
    Series named "quality_cat_id"."""
    quality = calculate_quality(scores)
    return quality.cat.codes.rename("quality_cat_id")
# Class balance of the derived quality labels over the whole dataset.
print(calculate_quality(data).value_counts())
print()
bad 3787 good 1177 moderate 949 spam 97 dtype: int64
scores = data.copy()
# Integer codes (0=spam, 1=bad, 2=moderate, 3=good) used as plot hues below.
scores["quality_cat_id"] = calculate_quality_code(scores)
print()
print(scores["quality_cat_id"].value_counts())
1 3787 3 1177 2 949 0 97 Name: quality_cat_id, dtype: int64
# Pairwise scatter matrix coloured by quality class
# (0=spam black, 1=bad red, 2=moderate blue, 3=good green).
default_pairplot = sns.pairplot(scores, hue = "quality_cat_id", palette = {
    0: "black",
    1: "red",
    2: "blue",
    3: "green",
}, height = 2)
default_pairplot
<seaborn.axisgrid.PairGrid at 0x7f98bb406a40>
# Same matrix with 2-D histograms instead of scatter points.
sns.pairplot(scores, hue = "quality_cat_id", kind="hist", height=1.5)
<seaborn.axisgrid.PairGrid at 0x7f0d6e7b9510>
# Unlabelled histogram pairplot (no hue), for overall distributions.
sns.pairplot(scores, kind="hist", height=1.5)
<seaborn.axisgrid.PairGrid at 0x7f0d6999da50>
# Smaller version of the coloured scatter matrix.
sns.pairplot(scores, hue = "quality_cat_id", palette = {
    0: "black",
    1: "red",
    2: "blue",
    3: "green",
}, height=1.5)
# NOTE(review): the alias `lgbm` is rebound to an LGBMClassifier instance
# in the training cell below, shadowing this module import.
import lightgbm as lgbm
data
| hotspot_id | blacklist_score | static_score | dynamic_score | connection_stats_score | last_conn_date | last_seen_date | num_conn | unique_conn | percent_available | percent_protected | enabled_moderator | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 14650480 | 0.0 | 0.22 | 0.45 | 0.69 | 2022-08-21 | 2021-05-01 | 5 | 2 | 1.0 | 1.0 | True |
| 1 | 14110275 | 0.0 | 0.22 | 0.00 | 0.67 | 2022-02-04 | 2022-02-04 | 4 | 2 | 1.0 | 1.0 | True |
| 2 | 16012785 | 0.0 | 0.18 | 0.16 | 0.67 | 2022-02-05 | 2022-02-15 | 4 | 2 | 1.0 | 1.0 | True |
| 3 | 14863945 | 0.0 | 0.22 | 0.05 | 0.72 | 2022-03-15 | 2021-06-12 | 6 | 2 | 1.0 | 0.8 | True |
| 4 | 9295867 | 0.0 | 0.39 | 0.00 | 0.52 | 2017-12-07 | 2017-12-07 | 1 | 1 | 1.0 | 0.0 | True |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 6006 | 13213372 | 0.0 | 0.25 | 0.00 | 0.52 | 2020-04-27 | 2020-04-27 | 1 | 1 | 1.0 | 0.0 | True |
| 6007 | 5504114 | 0.0 | 0.22 | 0.00 | 0.52 | 2016-04-10 | 2022-01-02 | 1 | 1 | 1.0 | 0.0 | True |
| 6008 | 15109612 | 0.0 | 0.48 | 0.00 | 0.52 | 2021-07-20 | 2021-07-20 | 1 | 1 | 1.0 | 0.0 | True |
| 6009 | 502326 | 0.0 | 0.12 | 0.00 | 0.52 | 2013-11-19 | 2013-11-19 | 1 | 1 | 1.0 | 0.0 | True |
| 6010 | 14242378 | 0.0 | 0.22 | 0.00 | 0.52 | 2021-01-29 | 2021-01-29 | 1 | 1 | 1.0 | 0.0 | True |
6010 rows × 12 columns
# Target: ordered quality code derived from the score columns.
y = calculate_quality_code(data)
print(y.info())
print()
# Drop columns that directly define the label (blacklist/dynamic scores)
# plus the identifier, which carries no predictive signal.
insight_columns = ["blacklist_score", "hotspot_id", "dynamic_score"]
X = data.drop(columns = insight_columns)
print(X.info())
<class 'pandas.core.series.Series'> Int64Index: 6010 entries, 0 to 6010 Series name: quality_cat_id Non-Null Count Dtype -------------- ----- 6010 non-null int8 dtypes: int8(1) memory usage: 181.9 KB None <class 'pandas.core.frame.DataFrame'> Int64Index: 6010 entries, 0 to 6010 Data columns (total 9 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 static_score 6010 non-null float64 1 connection_stats_score 6010 non-null float64 2 last_conn_date 6010 non-null datetime64[ns] 3 last_seen_date 6010 non-null datetime64[ns] 4 num_conn 6010 non-null int64 5 unique_conn 6010 non-null int64 6 percent_available 6010 non-null float64 7 percent_protected 6010 non-null float64 8 enabled_moderator 6010 non-null bool dtypes: bool(1), datetime64[ns](2), float64(4), int64(2) memory usage: 557.5 KB None
# pd.concat([data.loc[y_train.index], y_train], axis = 1)
# X_y_pairplot = sns.pairplot(pd.concat([X, y], axis = 1), hue = y.name, palette = {
# 0: "black",
# 1: "red",
# 2: "blue",
# 3: "green",
# }, height = 2)
# X_y_pairplot
# X_y_pairplot_2 = sns.pairplot(pd.concat([X, y], axis = 1), height = 2)
# X_y_pairplot_2
import pandas as pd  # NOTE(review): redundant — pandas is already imported above
import numpy as np
# Reference point for converting the date columns into day counts.
# NOTE(review): using now() makes these features non-reproducible between
# runs — consider a fixed snapshot date.
today = pd.Timestamp.now()
from lib import calculate_days_passed
# Replace raw timestamps with "days since" features the model can use
# (int16 per the info() output below).
X["last_conn_days"] = calculate_days_passed(X["last_conn_date"], today)
X["last_seen_days"] = calculate_days_passed(X["last_seen_date"], today)
X.drop(columns = ["last_conn_date", "last_seen_date"], inplace = True)
X.info()
<class 'pandas.core.frame.DataFrame'> Int64Index: 6010 entries, 0 to 6010 Data columns (total 9 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 static_score 6010 non-null float64 1 connection_stats_score 6010 non-null float64 2 num_conn 6010 non-null int64 3 unique_conn 6010 non-null int64 4 percent_available 6010 non-null float64 5 percent_protected 6010 non-null float64 6 enabled_moderator 6010 non-null bool 7 last_conn_days 6010 non-null int16 8 last_seen_days 6010 non-null int16 dtypes: bool(1), float64(4), int16(2), int64(2) memory usage: 487.1 KB
# default_pairplot = sns.pairplot(scores, hue = "quality_cat_id", palette = {
# 0: "black",
# 1: "red",
# 2: "blue",
# 3: "green",
# }, height = 2)
# default_pairplot
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) Input In [21], in <cell line: 1>() ----> 1 default_pairplot = sns.pairplot(scores, hue = "quality_cat_id", palette = { 2 0: "black", 3 1: "red", 4 2: "blue", 5 3: "green", 6 }, height = 2) 7 default_pairplot File ~/.local/lib/python3.10/site-packages/seaborn/_decorators.py:46, in _deprecate_positional_args.<locals>.inner_f(*args, **kwargs) 36 warnings.warn( 37 "Pass the following variable{} as {}keyword arg{}: {}. " 38 "From version 0.12, the only valid positional argument " (...) 43 FutureWarning 44 ) 45 kwargs.update({k: arg for k, arg in zip(sig.parameters, args)}) ---> 46 return f(**kwargs) File ~/.local/lib/python3.10/site-packages/seaborn/axisgrid.py:2140, in pairplot(data, hue, hue_order, palette, vars, x_vars, y_vars, kind, diag_kind, markers, height, aspect, corner, dropna, plot_kws, diag_kws, grid_kws, size) 2138 if kind == "scatter": 2139 from .relational import scatterplot # Avoid circular import -> 2140 plotter(scatterplot, **plot_kws) 2141 elif kind == "reg": 2142 from .regression import regplot # Avoid circular import File ~/.local/lib/python3.10/site-packages/seaborn/axisgrid.py:1387, in PairGrid.map_offdiag(self, func, **kwargs) 1376 """Plot with a bivariate function on the off-diagonal subplots. 1377 1378 Parameters (...) 1384 1385 """ 1386 if self.square_grid: -> 1387 self.map_lower(func, **kwargs) 1388 if not self._corner: 1389 self.map_upper(func, **kwargs) File ~/.local/lib/python3.10/site-packages/seaborn/axisgrid.py:1357, in PairGrid.map_lower(self, func, **kwargs) 1346 """Plot with a bivariate function on the lower diagonal subplots. 1347 1348 Parameters (...) 
1354 1355 """ 1356 indices = zip(*np.tril_indices_from(self.axes, -1)) -> 1357 self._map_bivariate(func, indices, **kwargs) 1358 return self File ~/.local/lib/python3.10/site-packages/seaborn/axisgrid.py:1539, in PairGrid._map_bivariate(self, func, indices, **kwargs) 1537 if ax is None: # i.e. we are in corner mode 1538 continue -> 1539 self._plot_bivariate(x_var, y_var, ax, func, **kws) 1540 self._add_axis_labels() 1542 if "hue" in signature(func).parameters: File ~/.local/lib/python3.10/site-packages/seaborn/axisgrid.py:1579, in PairGrid._plot_bivariate(self, x_var, y_var, ax, func, **kwargs) 1577 kwargs.setdefault("hue_order", self._hue_order) 1578 kwargs.setdefault("palette", self._orig_palette) -> 1579 func(x=x, y=y, **kwargs) 1581 self._update_legend_data(ax) File ~/.local/lib/python3.10/site-packages/seaborn/_decorators.py:46, in _deprecate_positional_args.<locals>.inner_f(*args, **kwargs) 36 warnings.warn( 37 "Pass the following variable{} as {}keyword arg{}: {}. " 38 "From version 0.12, the only valid positional argument " (...) 
43 FutureWarning 44 ) 45 kwargs.update({k: arg for k, arg in zip(sig.parameters, args)}) ---> 46 return f(**kwargs) File ~/.local/lib/python3.10/site-packages/seaborn/relational.py:827, in scatterplot(x, y, hue, style, size, data, palette, hue_order, hue_norm, sizes, size_order, size_norm, markers, style_order, x_bins, y_bins, units, estimator, ci, n_boot, alpha, x_jitter, y_jitter, legend, ax, **kwargs) 823 return ax 825 p._attach(ax) --> 827 p.plot(ax, kwargs) 829 return ax File ~/.local/lib/python3.10/site-packages/seaborn/relational.py:670, in _ScatterPlotter.plot(self, ax, kws) 668 self._add_axis_labels(ax) 669 if self.legend: --> 670 self.add_legend_data(ax) 671 handles, _ = ax.get_legend_handles_labels() 672 if handles: File ~/.local/lib/python3.10/site-packages/seaborn/relational.py:337, in _RelationalPlotter.add_legend_data(self, ax) 335 if attr in kws: 336 use_kws[attr] = kws[attr] --> 337 artist = func([], [], label=label, **use_kws) 338 if self._legend_func == "plot": 339 artist = artist[0] File ~/.local/lib/python3.10/site-packages/matplotlib/__init__.py:1601, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs) 1598 @functools.wraps(func) 1599 def inner(ax, *args, data=None, **kwargs): 1600 if data is None: -> 1601 return func(ax, *map(sanitize_sequence, args), **kwargs) 1603 bound = new_sig.bind(ax, *args, **kwargs) 1604 needs_label = (label_namer 1605 and "label" not in bound.arguments 1606 and "label" not in bound.kwargs) File ~/.local/lib/python3.10/site-packages/matplotlib/axes/_axes.py:4528, in Axes.scatter(self, x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, verts, edgecolors, plotnonfinite, **kwargs) 4525 self.set_ymargin(0.05) 4527 self.add_collection(collection) -> 4528 self.autoscale_view() 4530 return collection File ~/.local/lib/python3.10/site-packages/matplotlib/axes/_base.py:2496, in _AxesBase.autoscale_view(self, tight, scalex, scaley) 2491 # End of definition of internal function 'handle_single_axis'. 
2493 handle_single_axis( 2494 scalex, self._autoscaleXon, self._shared_x_axes, 'intervalx', 2495 'minposx', self.xaxis, self._xmargin, x_stickies, self.set_xbound) -> 2496 handle_single_axis( 2497 scaley, self._autoscaleYon, self._shared_y_axes, 'intervaly', 2498 'minposy', self.yaxis, self._ymargin, y_stickies, self.set_ybound) File ~/.local/lib/python3.10/site-packages/matplotlib/axes/_base.py:2449, in _AxesBase.autoscale_view.<locals>.handle_single_axis(scale, autoscaleon, shared_axes, interval, minpos, axis, margin, stickies, set_bound) 2446 dl.extend(x_finite) 2447 dl.extend(y_finite) -> 2449 bb = mtransforms.BboxBase.union(dl) 2450 # fall back on the viewlimits if this is not finite: 2451 vl = None File ~/.local/lib/python3.10/site-packages/matplotlib/transforms.py:703, in BboxBase.union(bboxes) 701 raise ValueError("'bboxes' cannot be empty") 702 x0 = np.min([bbox.xmin for bbox in bboxes]) --> 703 x1 = np.max([bbox.xmax for bbox in bboxes]) 704 y0 = np.min([bbox.ymin for bbox in bboxes]) 705 y1 = np.max([bbox.ymax for bbox in bboxes]) File ~/.local/lib/python3.10/site-packages/matplotlib/transforms.py:703, in <listcomp>(.0) 701 raise ValueError("'bboxes' cannot be empty") 702 x0 = np.min([bbox.xmin for bbox in bboxes]) --> 703 x1 = np.max([bbox.xmax for bbox in bboxes]) 704 y0 = np.min([bbox.ymin for bbox in bboxes]) 705 y1 = np.max([bbox.ymax for bbox in bboxes]) File ~/.local/lib/python3.10/site-packages/matplotlib/transforms.py:360, in BboxBase.xmax(self) 357 @property 358 def xmax(self): 359 """The right edge of the bounding box.""" --> 360 return np.max(self.get_points()[:, 0]) File <__array_function__ internals>:180, in amax(*args, **kwargs) File ~/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:2791, in amax(a, axis, out, keepdims, initial, where) 2675 @array_function_dispatch(_amax_dispatcher) 2676 def amax(a, axis=None, out=None, keepdims=np._NoValue, initial=np._NoValue, 2677 where=np._NoValue): 2678 """ 2679 Return the maximum of an 
array or maximum along an axis. 2680 (...) 2789 5 2790 """ -> 2791 return _wrapreduction(a, np.maximum, 'max', axis, None, out, 2792 keepdims=keepdims, initial=initial, where=where) File ~/.local/lib/python3.10/site-packages/numpy/core/fromnumeric.py:73, in _wrapreduction(obj, ufunc, method, axis, dtype, out, **kwargs) 69 def _wrapreduction(obj, ufunc, method, axis, dtype, out, **kwargs): 70 passkwargs = {k: v for k, v in kwargs.items() 71 if v is not np._NoValue} ---> 73 if type(obj) is not mu.ndarray: 74 try: 75 reduction = getattr(obj, method) KeyboardInterrupt:
Error in callback <function flush_figures at 0x7f98bdeb4820> (for post_execute):
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) File ~/.local/lib/python3.10/site-packages/matplotlib_inline/backend_inline.py:121, in flush_figures() 118 if InlineBackend.instance().close_figures: 119 # ignore the tracking, just draw and close all figures 120 try: --> 121 return show(True) 122 except Exception as e: 123 # safely show traceback if in IPython, else raise 124 ip = get_ipython() File ~/.local/lib/python3.10/site-packages/matplotlib_inline/backend_inline.py:43, in show(close, block) 39 try: 40 for figure_manager in Gcf.get_all_fig_managers(): 41 display( 42 figure_manager.canvas.figure, ---> 43 metadata=_fetch_figure_metadata(figure_manager.canvas.figure) 44 ) 45 finally: 46 show._to_draw = [] File ~/.local/lib/python3.10/site-packages/matplotlib_inline/backend_inline.py:231, in _fetch_figure_metadata(fig) 228 # determine if a background is needed for legibility 229 if _is_transparent(fig.get_facecolor()): 230 # the background is transparent --> 231 ticksLight = _is_light([label.get_color() 232 for axes in fig.axes 233 for axis in (axes.xaxis, axes.yaxis) 234 for label in axis.get_ticklabels()]) 235 if ticksLight.size and (ticksLight == ticksLight[0]).all(): 236 # there are one or more tick labels, all with the same lightness 237 return {'needs_background': 'dark' if ticksLight[0] else 'light'} File ~/.local/lib/python3.10/site-packages/matplotlib_inline/backend_inline.py:234, in <listcomp>(.0) 228 # determine if a background is needed for legibility 229 if _is_transparent(fig.get_facecolor()): 230 # the background is transparent 231 ticksLight = _is_light([label.get_color() 232 for axes in fig.axes 233 for axis in (axes.xaxis, axes.yaxis) --> 234 for label in axis.get_ticklabels()]) 235 if ticksLight.size and (ticksLight == ticksLight[0]).all(): 236 # there are one or more tick labels, all with the same lightness 237 return {'needs_background': 'dark' if ticksLight[0] else 
'light'} File ~/.local/lib/python3.10/site-packages/matplotlib/axis.py:1296, in Axis.get_ticklabels(self, minor, which) 1294 if minor: 1295 return self.get_minorticklabels() -> 1296 return self.get_majorticklabels() File ~/.local/lib/python3.10/site-packages/matplotlib/axis.py:1252, in Axis.get_majorticklabels(self) 1250 def get_majorticklabels(self): 1251 'Return a list of Text instances for the major ticklabels.' -> 1252 ticks = self.get_major_ticks() 1253 labels1 = [tick.label1 for tick in ticks if tick.label1.get_visible()] 1254 labels2 = [tick.label2 for tick in ticks if tick.label2.get_visible()] File ~/.local/lib/python3.10/site-packages/matplotlib/axis.py:1407, in Axis.get_major_ticks(self, numticks) 1405 'Get the tick instances; grow as necessary.' 1406 if numticks is None: -> 1407 numticks = len(self.get_majorticklocs()) 1409 while len(self.majorTicks) < numticks: 1410 # Update the new tick label properties from the old. 1411 tick = self._get_tick(major=True) File ~/.local/lib/python3.10/site-packages/matplotlib/axis.py:1324, in Axis.get_majorticklocs(self) 1322 def get_majorticklocs(self): 1323 """Get the array of major tick locations in data coordinates.""" -> 1324 return self.major.locator() File ~/.local/lib/python3.10/site-packages/matplotlib/ticker.py:2078, in MaxNLocator.__call__(self) 2076 def __call__(self): 2077 vmin, vmax = self.axis.get_view_interval() -> 2078 return self.tick_values(vmin, vmax) File ~/.local/lib/python3.10/site-packages/matplotlib/ticker.py:2084, in MaxNLocator.tick_values(self, vmin, vmax) 2082 vmax = max(abs(vmin), abs(vmax)) 2083 vmin = -vmax -> 2084 vmin, vmax = mtransforms.nonsingular( 2085 vmin, vmax, expander=1e-13, tiny=1e-14) 2086 locs = self._raw_ticks(vmin, vmax) 2088 prune = self._prune File ~/.local/lib/python3.10/site-packages/matplotlib/transforms.py:2828, in nonsingular(vmin, vmax, expander, tiny, increasing) 2825 swapped = True 2827 maxabsvalue = max(abs(vmin), abs(vmax)) -> 2828 if maxabsvalue < (1e6 / tiny) 
* np.finfo(float).tiny: 2829 vmin = -expander 2830 vmax = expander File ~/.local/lib/python3.10/site-packages/numpy/core/getlimits.py:577, in finfo.tiny(self) 562 @property 563 def tiny(self): 564 """Return the value for tiny, alias of smallest_normal. 565 566 Returns (...) 575 double-double. 576 """ --> 577 return self.smallest_normal File ~/.local/lib/python3.10/site-packages/numpy/core/getlimits.py:556, in finfo.smallest_normal(self) 541 """Return the value for the smallest normal. 542 543 Returns (...) 552 double-double. 553 """ 554 # This check is necessary because the value for smallest_normal is 555 # platform dependent for longdouble types. --> 556 if isnan(self._machar.smallest_normal.flat[0]): 557 warnings.warn( 558 'The value of smallest normal is undefined for double double', 559 UserWarning, stacklevel=2) 560 return self._machar.smallest_normal.flat[0] KeyboardInterrupt:
# Keep only the three features used for this experiment. (The original
# first assigned X_for_exp = X and immediately overwrote it — dead code,
# removed.)
X_for_exp = X[["connection_stats_score", "last_seen_days", "last_conn_days"]]
y_for_exp = y.loc[X.index]
# Quick sanity check of the recency feature's distribution.
X_for_exp["last_seen_days"].hist()
<matplotlib.axes._subplots.AxesSubplot at 0x7f98100540a0>
from typing import List
import plotly.graph_objs as go
from sklearn.metrics import confusion_matrix
def confusion_matrix_plot(y_test: List[str], y_pred: List[str], labels: List[str], display_labels: List[str] = None, normalise: bool = False) -> go.Figure:
    """Build a Plotly heatmap of the confusion matrix for *y_pred* vs *y_test*.

    Parameters
    ----------
    y_test, y_pred : list-like
        True and predicted labels (any values confusion_matrix accepts;
        the List[str] hints are looser in practice — ints are passed below).
    labels : list-like
        Label values fixing the row/column order of the matrix.
    display_labels : list-like, optional
        Axis tick labels; defaults to *labels*.
    normalise : bool
        When True, normalise each row to proportions of its true class.

    Returns
    -------
    go.Figure
    """
    # Compute the confusion matrix with a fixed label order.
    cm = confusion_matrix(y_test, y_pred, labels=labels)
    if normalise:
        # Row-normalise, guarding against division by zero for labels with
        # no true samples (the original produced NaN rows in that case).
        row_totals = cm.sum(axis=1, keepdims=True)
        cm = np.divide(
            cm.astype('float'), row_totals,
            out=np.zeros_like(cm, dtype=float), where=row_totals != 0,
        )
    # Define the display labels
    if display_labels is None:
        display_labels = labels
    # Heatmap trace; named `heatmap` rather than `data` so it no longer
    # shadows the module-level DataFrame of the same name.
    heatmap = go.Heatmap(
        z=cm,
        x=display_labels,
        y=display_labels,
        colorscale='YlGnBu'
    )
    # Define the layout of the plot
    layout = go.Layout(
        title='Confusion Matrix',
        xaxis=dict(title='Predicted label'),
        yaxis=dict(title='True label')
    )
    # Create the plot
    fig = go.Figure(data=[heatmap], layout=layout)
    return fig
from lightgbm import LGBMClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
# Split into training and testing sets.
# NOTE(review): no stratify= is passed, so the rare "spam" class (97 rows
# total) can be under-represented in either split — consider
# stratify=y_for_exp.
X_train, X_test, y_train, y_test = train_test_split(X_for_exp, y_for_exp, test_size=0.2, random_state=SEED)
# LightGBM hyper-parameters for the 4-class quality problem.
params = {
    'num_classes': 4,
    'objective': 'multiclass',
    'random_state': SEED,
    'n_estimators': 100,
    'max_depth': 10,
    'learning_rate': 0.1
}
# NOTE(review): this rebinding shadows the `import lightgbm as lgbm`
# module alias from the earlier cell.
lgbm = LGBMClassifier(**params)
# Fit the model on the training data
lgbm.fit(X_train, y_train)
# Predict the class labels of the testing data
y_pred = lgbm.predict(X_test)
# Compare class balance between the splits and the predictions.
print("Train classes: ")
print(y_train.value_counts())
print("Test classes: ")
print(y_test.value_counts())
print("Test predicts: ")
print( pd.DataFrame(y_pred).value_counts() )
# Print the classification report
print(classification_report(y_test, y_pred))
# Pearson correlation of each experiment feature with the quality code
# (last row/column is the target itself, hence iloc[:-1] below).
correlations = pd.concat([X_for_exp, y_for_exp], axis=1).corr()
# Print the correlation between each feature and the target variable
print("Correlation to quality")
print(correlations[y.name].iloc[:-1])
# Row-normalised and raw confusion matrices.
cm = confusion_matrix_plot(y_test, y_pred, lgbm.classes_, display_labels = SORTED_QUALITY, normalise = True)
cm.show()
cm = confusion_matrix_plot(y_test, y_pred, lgbm.classes_, display_labels = SORTED_QUALITY, normalise = False)
cm.show()
[LightGBM] [Warning] Accuracy may be bad since you didn't explicitly set num_leaves OR 2^max_depth > num_leaves. (num_leaves=31).
[LightGBM] [Warning] Accuracy may be bad since you didn't explicitly set num_leaves OR 2^max_depth > num_leaves. (num_leaves=31).
[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000113 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 529
[LightGBM] [Info] Number of data points in the train set: 4808, number of used features: 3
[LightGBM] [Info] Start training from score -4.096010
[LightGBM] [Info] Start training from score -0.461389
[LightGBM] [Info] Start training from score -1.838161
[LightGBM] [Info] Start training from score -1.640704
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] Accuracy may be bad since you didn't explicitly set num_leaves OR 2^max_depth > num_leaves. (num_leaves=31).
Train classes:
1 3031
3 932
2 765
0 80
Name: quality_cat_id, dtype: int64
Test classes:
1 756
3 245
2 184
0 17
Name: quality_cat_id, dtype: int64
Test predicts:
1 766
3 225
2 209
0 2
dtype: int64
precision recall f1-score support
0 0.00 0.00 0.00 17
1 0.96 0.97 0.96 756
2 0.76 0.86 0.81 184
3 0.96 0.88 0.92 245
accuracy 0.92 1202
macro avg 0.67 0.68 0.67 1202
weighted avg 0.91 0.92 0.92 1202
Correlation to quality
connection_stats_score 0.859035
last_seen_days -0.585585
last_conn_days -0.620113
Name: quality_cat_id, dtype: float64
# Rank the experiment features by LightGBM's split importance and print
# them from most to least important.
ranking = sorted(
    zip(X_for_exp.columns, lgbm.feature_importances_),
    key=lambda pair: pair[1],
    reverse=True,
)
for feature, importance in ranking:
    print('{}: {}'.format(feature, importance))
last_seen_days: 5287 last_conn_days: 5155 connection_stats_score: 1540
import pickle
# Save the trained classifier to a file using pickle
# (reload later with pickle.load; requires the same lightgbm version).
with open('lgbm_model.pkl', 'wb') as f:
    pickle.dump(lgbm, f)